package com.template.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 红黑树
 *
 * 1. 每个节点或红色或黑色
 * 2. 根节点是黑色的
 * 3. 所有叶子节点是黑色的（叶子节点为 null 的节点）
 * 4. 红色节点的两个子节点都是黑色的
 * 5. 从任意一个节点到其每个叶子节点的所有路径都包含相同数量的黑色节点
 */
public class RedBlackTree<T extends Comparable<T>> {

    private static final boolean RED = true; // 红色节点
    private static final boolean BLACK = false; // 黑色节点

    /**
     * 树节点
     */
    private class Node {
        T value; // 存储的值
        Node left, right; // 左右子节点
        boolean color; // 节点颜色

        public Node(T value) {
            this.value = value;
            color = RED; // 新增节点默认为红色
        }
    }

    private Node root; // 根节点

    /**
     * 判断节点是否为红色
     *
     * @param node 节点
     * @return 是否为红色
     */
    private boolean isRed(Node node) {
        if (node == null)
            return false;
        return node.color == RED;
    }

    /**
     * 左旋转
     *
     * @param node 原本的根节点
     * @return 新的根节点
     */
    private Node rotateLeft(Node node) {
        Node x = node.right;
        node.right = x.left;
        x.left = node;
        x.color = node.color;
        node.color = RED;
        return x;
    }

    /**
     * 右旋转
     *
     * @param node 原本的根节点
     * @return 新的根节点
     */
    private Node rotateRight(Node node) {
        Node x = node.left;
        node.left = x.right;
        x.right = node;
        x.color = node.color;
        node.color = RED;
        return x;
    }

    /**
     * 颜色翻转
     *
     * @param node 要翻转的节点
     */
    private void flipColors(Node node) {
        node.color = RED;
        node.left.color = BLACK;
        node.right.color = BLACK;
    }

    /**
     * 在树中插入一个元素
     *
     * @param value 要插入的元素值
     */
    public void insert(T value) {
        root = insert(root, value);
        root.color = BLACK; // 根节点是黑色
    }

    /**
     * 递归插入
     *
     * @param node  当前节点
     * @param value 要插入的元素值
     * @return 新的节点
     */
    private Node insert(Node node, T value) {
        if (node == null)
            return new Node(value);

        int cmp = value.compareTo(node.value);
        if (cmp < 0)
            node.left = insert(node.left, value);
        else if (cmp > 0)
            node.right = insert(node.right, value);
        else
            node.value = value;

        // 修复红黑树性质
        /**
         * 通过左旋转和右旋转操作，将红色节点移到左侧，以符合性质 4
         * 通过左旋转和右旋转操作，解决出现连续两个红色节点的问题，以符合性质 4
         * 通过颜色翻转操作，解决出现红父节点和黑子节点之间夹了一个红节点的问题，以符合性质 5
         */
        if (isRed(node.right) && !isRed(node.left))
            node = rotateLeft(node);
        if (isRed(node.left) && isRed(node.left.left))
            node = rotateRight(node);
        if (isRed(node.left) && isRed(node.right))
            flipColors(node);

        return node;
    }

    /**
     * 从树中删除一个元素
     *
     * @param value 要删除的元素值
     */
    public void remove(T value) {
        root = remove(root, value);
        if (root != null) {
            root.color = BLACK; // 根节点是黑色
        }
    }

    /**
     * 递归删除
     *
     * @param node  当前节点
     * @param value 要删除的元素值
     * @return 新的节点
     */
    private Node remove(Node node, T value) {
        if (node == null) {
            return null;
        }

        int cmp = value.compareTo(node.value);
        if (cmp < 0) {
            node.left = remove(node.left, value);
        } else if (cmp > 0) {
            node.right = remove(node.right, value);
        } else {
            if (node.left == null) {
                return node.right;
            } else if (node.right == null) {
                return node.left;
            } else {
                Node temp = min(node.right);
                node.value = temp.value;
                node.right = remove(node.right, temp.value);
            }
        }

        // 修复红黑树性质
        if (isRed(node.right) && !isRed(node.left))
            node = rotateLeft(node);
        if (isRed(node.left) && isRed(node.left.left))
            node = rotateRight(node);
        if (isRed(node.left) && isRed(node.right))
            flipColors(node);

        return node;
    }

    /**
     * 查找子树中的最小节点
     *
     * @param node 子树的根节点
     * @return 最小节点
     */
    private Node min(Node node) {
        while (node.left != null) {
            node = node.left;
        }
        return node;
    }


    /**
     * ===================================实现红黑树的金字塔结构打印===================================================
     */
    private class OutputNode {
        T value;
        boolean color;
        int x; // 节点在控制台中的横坐标
        int y; // 节点在控制台中的纵坐标

        public OutputNode(T value, boolean color) {
            this.value = value;
            this.color = color;
        }
    }


    public void printTree() {
        List<OutputNode> outputNodes = new ArrayList<>();
        calculatePosition(root, outputNodes, 0, 0);
        printNodes(outputNodes);
    }

    private void calculatePosition(Node node, List<OutputNode> outputNodes, int x, int y) {
        if (node != null) {
            OutputNode outputNode = new OutputNode(node.value, node.color);
            outputNode.x = x;
            outputNode.y = y;
            outputNodes.add(outputNode);
            calculatePosition(node.left, outputNodes, x - 2, y + 1);
            calculatePosition(node.right, outputNodes, x + 2, y + 1);
        }
    }

    private void printNodes(List<OutputNode> outputNodes) {
        int minX = 0, maxX = 0;
        int maxY = 0;
        for (OutputNode node : outputNodes) {
            if (node.x < minX) {
                minX = node.x;
            }
            if (node.x > maxX) {
                maxX = node.x;
            }
            if (node.y > maxY) {
                maxY = node.y;
            }
        }
        int rows = maxY + 1;
        int cols = (maxX - minX + 1) * 2 - 1;
        String[][] matrix = new String[rows][cols];
        for (String[] row : matrix) {
            Arrays.fill(row, " ");
        }
        for (OutputNode node : outputNodes) {
            int x = (node.x - minX) * 2;
            int y = node.y;
            matrix[y][x] = String.valueOf(node.value);
            if (node.color == RED) {
                matrix[y][x] = "\033[31m" + matrix[y][x] + "\033[0m"; // 红色字符
            }
        }
        for (int i = rows - 1; i >= 0; i--) { // 从下往上逐行输出
            StringBuilder sb = new StringBuilder();
            for (String s : matrix[i]) {
                sb.append(s).append(" ");
            }
            System.out.println(sb.toString());
        }
    }

}
